// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <optional>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/run_operation.hpp"

namespace ttnn::operations::reduction {

/*
 * Operation descriptor for ArgMax in ttnn::operations::reduction, exposing the
 * hooks used by the tt_metal operation framework (see "ttnn/run_operation.hpp").
 *
 * All members are const and no constructor is declared, so this struct is an
 * aggregate. Callers presumably build it with brace (aggregate) initialization.
 * That form binds values positionally to the member declaration order below,
 * so do not reorder or insert members without updating every construction site.
 */
struct ArgMax {
    // Data type requested for the output tensor (presumably the index dtype, e.g. an integer type — confirm).
    const tt::tt_metal::DataType output_dtype;
    // Dimension to reduce over.
    // NOTE(review): presumably std::nullopt means "reduce over all elements" — confirm in the implementation.
    const std::optional<int> dim;
    // Controls whether the reduced dimension is retained in the output shape (consulted by get_output_shape).
    const bool keepdim;
    // Optional set of core ranges. Presumably this restricts which cores the program runs on — confirm.
    const std::optional<CoreRangeSet> sub_core_grids;
    // Flag that presumably selects a multi-core rather than single-core program in create_program — confirm.
    const bool use_multicore;
    // Memory configuration applied to the output tensor.
    const tt::tt_metal::MemoryConfig output_mem_config;

    /*
     * Generates the output shape for the reduction operation.
     * The output shape is derived by iterating over the input shape and then
     * adjusting the result according to keepdim.
     * @param input_tensor The input tensor on which the reduction is performed.
     * @return The output shape, one uint32_t extent per output dimension.
     */
    ttnn::SmallVector<uint32_t> get_output_shape(const Tensor& input_tensor) const;
    /*
     * Validates the input tensors and any caller-provided (optional) output tensors
     * before the operation runs.
     * NOTE(review): presumably reports failures by throwing/asserting — confirm.
     */
    void validate_with_output_tensors(
        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
    /*
     * Computes the TensorSpec for each output tensor from the inputs. A spec bundles
     * the shape and layout information; output_dtype and output_mem_config presumably
     * feed into it — confirm.
     */
    std::vector<TensorSpec> compute_output_specs(
        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
    /*
     * Produces the output tensors for this operation.
     * NOTE(review): presumably reuses any provided entry in output_tensors and
     * allocates the rest from compute_output_specs — confirm.
     */
    std::vector<Tensor> create_output_tensors(
        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
    /*
     * Builds the device program, plus any runtime-argument callbacks, that computes
     * ArgMax from input_tensors into output_tensors. output_tensors is taken by
     * non-const reference, matching the framework hook signature.
     */
    tt::tt_metal::operation::ProgramWithCallbacks create_program(
        const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
};

}  // namespace ttnn::operations::reduction
